from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.ops.op_fusion import LOADED_IMAGES
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.mm_utils import load_image
from data_juicer.utils.model_utils import get_model, prepare_model

# Registry name for this operator; used by the OPERATORS / LOADED_IMAGES
# registration decorators on MllmMapper below.
OP_NAME = "mllm_mapper"

# Deferred imports via LazyLoader so importing this module does not load
# torch/transformers up front (presumably resolved on first attribute
# access -- see data_juicer.utils.lazy_loader).
torch = LazyLoader("torch")
transformers = LazyLoader("transformers")


@LOADED_IMAGES.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class MllmMapper(Mapper):
    """Mapper to use MLLMs for visual question answering tasks. This operator uses a Hugging
    Face model to generate answers based on input text and images. It supports models like
    `llava-hf/llava-v1.6-vicuna-7b-hf` and `Qwen/Qwen2-VL-7B-Instruct`. For each sample, the
    sample's text is used as the question and one answer is generated for every distinct
    image the sample references (duplicate image paths are answered only once). The sample's
    text field is then replaced by the list of generated answers, in image order; samples
    without images are returned unchanged. The key parameters include the model ID, maximum
    new tokens, temperature, top-p sampling, and beam search size, which control the
    generation process."""

    _accelerator = "cuda"

    def __init__(
        self,
        hf_model: str = "llava-hf/llava-v1.6-vicuna-7b-hf",
        max_new_tokens=256,
        temperature=0.2,
        top_p=None,
        num_beams=1,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param hf_model: Hugging Face model id.
        :param max_new_tokens: the maximum number of new tokens
            generated by the model.
        :param temperature: used to control the randomness of
            generated text. The higher the temperature, the more
            random and creative the generated text will be.
        :param top_p: randomly select the next word from the group
            of words whose cumulative probability reaches p.
        :param num_beams: the larger the beam search size, the higher
            the quality of the generated text.
        :param args: extra args
        :param kwargs: extra args; ``mem_required`` defaults to "32GB"
            when unset or 0, and ``num_proc`` defaults to 1 when unset.
        """
        torch.set_num_threads(1)

        kwargs["mem_required"] = "32GB" if kwargs.get("mem_required", 0) == 0 else kwargs["mem_required"]
        kwargs["num_proc"] = 1 if kwargs.get("num_proc", None) is None else kwargs["num_proc"]
        super().__init__(*args, **kwargs)

        self.hf_model = hf_model
        self.model_key = prepare_model(model_type="huggingface", pretrained_model_name_or_path=hf_model)
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams

    def process_single(self, sample=None, rank=None):
        """Answer the sample's text question once for every distinct image.

        :param sample: a single sample dict holding the question under
            ``self.text_key`` and a list of image paths under
            ``self.image_key``.
        :param rank: worker rank forwarded to ``get_model`` (presumably
            used to pick the device -- confirm in ``get_model``).
        :return: the same sample object. If it has images,
            ``sample[self.text_key]`` is replaced by a list of answer
            strings, one per distinct image path, in first-seen order.
        """
        # there is no image in this sample
        if self.image_key not in sample or not sample[self.image_key]:
            return sample

        # Load each distinct image path only once; dict keeps first-seen order.
        images = {}
        for image_path in sample[self.image_key]:
            if image_path not in images:
                images[image_path] = load_image(image_path)

        model, processor = get_model(model_key=self.model_key, rank=rank, use_cuda=self.use_cuda())

        # Single-turn chat: the sample text is the question about one image.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": sample[self.text_key]},
                    {"type": "image"},
                ],
            },
        ]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

        # Decoder-only models (LLaVA, Qwen2-VL, ...) return prompt + completion
        # from generate(); the prompt tokens must be dropped so that only the
        # model's answer is stored. Encoder-decoder models return only the
        # decoder output, so nothing is stripped for them.
        model_config = getattr(model, "config", None)
        strip_prompt = not getattr(model_config, "is_encoder_decoder", False)

        responses = []
        for image in images.values():
            inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

            # NOTE(review): in HF transformers, temperature/top_p only take
            # effect when sampling is enabled (do_sample). do_sample is not
            # passed here, so the model's own generation config decides whether
            # they are used -- confirm this is intended.
            output_ids = model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                num_beams=self.num_beams,
            )[0]

            if strip_prompt:
                output_ids = output_ids[inputs["input_ids"].shape[-1]:]

            responses.append(processor.decode(output_ids.cpu(), skip_special_tokens=True))

        # Overwrite the text only after every answer was produced, so a failure
        # mid-loop does not leave a partially filled list behind.
        sample[self.text_key] = responses
        return sample
